import numpy
import tensorflow as tf
from tensorflow.python import keras
from tensorflow.python.keras.layers import Dense, Dropout
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.optimizers import RMSprop

from tensorflowonspark import TFNode


def _load_mnist(num_pixels, num_classes):
  """Load MNIST through the Keras dataset API, flattened and one-hot encoded.

  Args:
    num_pixels: length of each flattened image vector.
    num_classes: number of label columns for the one-hot encoding.

  Returns:
    ``((x_train, y_train), (x_test, y_test))``. Images are float32 arrays of
    shape ``(N, num_pixels)`` divided by 255; labels are one-hot matrices with
    ``num_classes`` columns.
  """
  from tensorflow.python.keras.datasets import mnist
  (x_train, y_train), (x_test, y_test) = mnist.load_data()
  # -1 lets numpy infer the number of images instead of hard-coding split sizes.
  x_train = x_train.reshape(-1, num_pixels).astype('float32') / 255
  x_test = x_test.reshape(-1, num_pixels).astype('float32') / 255
  # convert class vectors to binary class matrices
  y_train = keras.utils.to_categorical(y_train, num_classes)
  y_test = keras.utils.to_categorical(y_test, num_classes)
  return (x_train, y_train), (x_test, y_test)


def _feed_input_fn(tf_feed, batch_size, num_pixels, num_classes):
  """Build an Estimator ``input_fn`` that streams batches from a TFoS DataFeed.

  Each record returned by ``tf_feed.next_batch`` is a ``(pixels, label)`` pair
  (the RDD built in ``__main__`` zips image rows with label rows). Batches are
  formed by ``next_batch(batch_size)`` itself rather than ``Dataset.batch``. A
  short batch handed back by the feed (e.g. at a partition boundary when the
  feed was created with ``train_mode=False``) is therefore emitted right away.
  It is not held while waiting for records that may never arrive.

  Args:
    tf_feed: the ``TFNode.DataFeed`` to read from.
    batch_size: maximum number of records per emitted batch.
    num_pixels: length of each flattened image vector.
    num_classes: length of each label vector.

  Returns:
    A zero-argument ``input_fn`` returning a ``tf.data.Dataset`` of
    ``(images, labels)`` float32 tensors shaped ``[None, num_pixels]`` and
    ``[None, num_classes]``; pixel values are divided by 255.
  """
  def batch_generator():
    # Keep pulling until the feed reports it should stop.
    while not tf_feed.should_stop():
      records = tf_feed.next_batch(batch_size)
      if records:
        images = numpy.array([r[0] for r in records]).astype(numpy.float32) / 255.0
        labels = numpy.array([r[1] for r in records]).astype(numpy.float32)
        yield (images, labels)

  def input_fn():
    return tf.data.Dataset.from_generator(
        batch_generator,
        (tf.float32, tf.float32),
        (tf.TensorShape([None, num_pixels]), tf.TensorShape([None, num_classes])))

  return input_fn


def main_fun(args, ctx):
  """TensorFlowOnSpark node function: train or run inference with a Keras MLP on MNIST.

  The Keras model is converted to a ``tf.estimator`` whose checkpoints live in
  ``args.model_dir``.

  Args:
    args: parsed command-line arguments (see ``__main__``). Reads ``mode``,
      ``input_mode``, ``model_dir``, ``batch_size``, ``epochs`` and ``steps``.
    ctx: TFoS node context; ``ctx.mgr`` is used to build the Spark DataFeed.
  """
  IMAGE_PIXELS = 28
  num_pixels = IMAGE_PIXELS * IMAGE_PIXELS
  num_classes = 10

  # setup a Keras model
  model = Sequential()
  model.add(Dense(512, activation='relu', input_shape=(num_pixels,)))
  model.add(Dropout(0.2))
  model.add(Dense(512, activation='relu'))
  model.add(Dropout(0.2))
  model.add(Dense(num_classes, activation='softmax'))
  model.compile(loss='categorical_crossentropy',
                optimizer=RMSprop(),
                metrics=['accuracy'])
  model.summary()

  # Feature-dict key for the model input. Read it from the model rather than
  # hard-coding Keras' auto-generated name (e.g. "dense_1_input"). That name
  # depends on how many Dense layers were created earlier in the process.
  input_name = model.input_names[0]

  # convert Keras model to tf.estimator
  estimator = tf.keras.estimator.model_to_estimator(model, model_dir=args.model_dir)

  if args.mode == 'train':
    # The in-memory dataset is only needed for training/eval, so skip the download otherwise.
    (x_train, y_train), (x_test, y_test) = _load_mnist(num_pixels, num_classes)

    # setup train_input_fn for InputMode.TENSORFLOW or InputMode.SPARK
    if args.input_mode == 'tf':
      # For InputMode.TENSORFLOW, just use data in memory
      train_input_fn = tf.estimator.inputs.numpy_input_fn(
          x={input_name: x_train},
          y=y_train,
          batch_size=128,
          num_epochs=None,
          shuffle=True)
    else:  # 'spark'
      # For InputMode.SPARK, read data from RDD
      train_input_fn = _feed_input_fn(TFNode.DataFeed(ctx.mgr), args.batch_size,
                                      num_pixels, num_classes)

    # eval_input_fn ALWAYS uses data loaded in memory, since InputMode.SPARK can only feed one RDD at a time
    eval_input_fn = tf.estimator.inputs.numpy_input_fn(
        x={input_name: x_test},
        y=y_test,
        num_epochs=args.epochs,
        shuffle=False)

    # serving_input_receiver_fn ALWAYS expects serialized TFExamples in a placeholder.
    def serving_input_receiver_fn():
      """An input receiver that expects a batch of serialized tf.Example protos."""
      # shape=[None] so serving clients may send any number of examples per request.
      serialized_tf_example = tf.placeholder(dtype=tf.string,
                                             shape=[None],
                                             name='input_example_tensor')
      receiver_tensors = {input_name: serialized_tf_example}
      # Pixels must be parsed as float32 to match the model's float input layer.
      feature_spec = {input_name: tf.FixedLenFeature([num_pixels], tf.float32)}
      features = tf.parse_example(serialized_tf_example, feature_spec)
      return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)

    # setup tf.estimator.train_and_evaluate() w/ FinalExporter
    exporter = tf.estimator.FinalExporter("serving", serving_input_receiver_fn=serving_input_receiver_fn)
    train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=args.steps)
    eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn, exporters=exporter)
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

  else:  # mode == 'inference'
    if args.input_mode == 'spark':
      # train_mode=False lets next_batch() return a short batch at the end of a
      # partition. Every fed record then yields a prediction instead of waiting
      # for a full batch.
      tf_feed = TFNode.DataFeed(ctx.mgr, train_mode=False)
      predict_input_fn = _feed_input_fn(tf_feed, args.batch_size, num_pixels, num_classes)
      # Hand each per-example prediction back to the feed individually.
      for result in estimator.predict(predict_input_fn):
        tf_feed.batch_results([result])
    else:
      print("inference is only implemented for input_mode='spark'; nothing to do")


if __name__ == '__main__':
  import argparse
  from pyspark.context import SparkContext
  from pyspark.conf import SparkConf
  from tensorflowonspark import TFCluster

  sc = SparkContext(conf=SparkConf().setAppName("mnist_mlp"))
  # Default --cluster_size to the requested executor count (1 when that property is unset).
  executors = sc.getConf().get("spark.executor.instances")
  num_executors = int(executors) if executors is not None else 1

  parser = argparse.ArgumentParser()
  parser.add_argument("--batch_size", help="number of records per batch", type=int, default=100)
  parser.add_argument("--cluster_size", help="number of nodes in the cluster", type=int, default=num_executors)
  parser.add_argument("--epochs", help="number of epochs of training data", type=int, default=1)
  parser.add_argument("--export_dir", help="directory to export saved_model")
  parser.add_argument("--images", help="HDFS path to MNIST images in parallelized CSV format")
  parser.add_argument("--input_mode", help="input mode (tf|spark)", choices=["tf", "spark"], default="tf")
  parser.add_argument("--labels", help="HDFS path to MNIST labels in parallelized CSV format")
  parser.add_argument("--model_dir", help="directory to write model checkpoints")
  # NOTE: if --mode is omitted, args.mode is None and both this script and
  # main_fun take their non-'train' (inference) branch.
  parser.add_argument("--mode", help="(train|inference)", choices=["train", "inference"])
  parser.add_argument("--output", help="HDFS path to save test/inference output", default="predictions")
  parser.add_argument("--num_ps", help="number of ps nodes", type=int, default=1)
  parser.add_argument("--steps", help="max number of steps to train", type=int, default=2000)
  parser.add_argument("--tensorboard", help="launch tensorboard process", action="store_true")

  args = parser.parse_args()
  print("args:", args)

  if args.input_mode == 'tf':
    # for TENSORFLOW mode, each node will load/train/infer entire dataset in memory per original example
    cluster = TFCluster.run(sc, main_fun, args, args.cluster_size, args.num_ps, args.tensorboard,
                            TFCluster.InputMode.TENSORFLOW, log_dir=args.model_dir, master_node='master')
    cluster.shutdown()
  else:  # 'spark'
    # Fail fast with a usage message instead of an opaque error from sc.textFile(None).
    if not (args.images and args.labels):
      parser.error("--images and --labels are required when --input_mode is spark")

    # for SPARK mode, just use CSV format as an example
    images = sc.textFile(args.images).map(lambda ln: [float(x) for x in ln.split(',')])
    labels = sc.textFile(args.labels).map(lambda ln: [float(x) for x in ln.split(',')])
    dataRDD = images.zip(labels)
    if args.mode == 'train':
      cluster = TFCluster.run(sc, main_fun, args, args.cluster_size, args.num_ps, args.tensorboard,
                              TFCluster.InputMode.SPARK, log_dir=args.model_dir, master_node='master')
      cluster.train(dataRDD, args.epochs)
      cluster.shutdown()
    else:
      # Note: using "parallel" inferencing, not "cluster"
      # each node loads the model and runs independently of others
      cluster = TFCluster.run(sc, main_fun, args, args.cluster_size, 0, args.tensorboard,
                              TFCluster.InputMode.SPARK, log_dir=args.model_dir)
      resultRDD = cluster.inference(dataRDD)
      resultRDD.saveAsTextFile(args.output)
      cluster.shutdown()
